import torch

# Select the compute device: use Apple's Metal Performance Shaders (MPS)
# backend when it is available, otherwise fall back to the CPU.
# `torch.backends.mps` only exists in PyTorch >= 1.12, so look it up
# defensively; older installs then fall back to the CPU instead of raising
# AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(device)

# Allocate the sample tensor directly on the chosen device instead of creating
# it on the CPU and copying it over with `.to(device)`.
x = torch.rand(3, 3, device=device)
print(f"Tensor on device: {x.device}")
